from torch import Tensor
import torch.nn as nn
import math
from functools import partial
from typing import Iterable
import torch
import torch.nn.functional as F


class Imod2(nn.Module):
    """Identity module for exactly two inputs.

    ``forward(a, b)`` hands both inputs back unchanged as the tuple ``(a, b)``.
    """

    def forward(self, a, b):
        pair = (a, b)
        return pair


class Imod(nn.Module):
    """Identity module accepting any number of positional inputs.

    ``forward(*inputs)`` returns the positional inputs as a tuple, unchanged.
    """

    def forward(self, *inputs):
        return inputs


class Seq(nn.ModuleList):
    """Sequential container whose first module may take several inputs.

    Unlike :class:`torch.nn.Sequential`, the first module receives every
    positional and keyword argument passed to :meth:`forward`. Each later
    module receives only the previous module's output.

    Args:
        modules: Modules to chain, in call order.
    """

    def __init__(self, modules: Iterable[nn.Module]):
        super().__init__(modules)

    def forward(self, *args, **kwargs):
        """Run the chain and return the last module's output.

        Raises:
            IndexError: if the container holds no modules.
        """
        if len(self) == 0:
            raise IndexError("Seq.forward called on an empty Seq")
        # Index/iterate the container itself; ``self.__getitem__[i]`` would
        # subscript the bound method and raise TypeError.
        first, *rest = self
        x = first(*args, **kwargs)
        for module in rest:
            x = module(x)
        return x


class Resmod(nn.Module):
    """Residual wrapper that returns ``mod(x) + x``.

    ``mod`` is called before ``x`` is read for the sum. A wrapped module that
    modifies its input in place (e.g. an ``inplace=True`` activation)
    therefore also changes the residual term.

    Args:
        mod: Module applied on the residual branch.
    """

    def __init__(self, mod: nn.Module):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        branch = self.mod(x)
        return branch + x


class MaxMinNorm(nn.Module):
    """Affine rescaling that maps ``[lower_bound, upper_bound]`` onto ``[0, 1]``.

    Args:
        lower_bound: Value mapped to 0.
        upper_bound: Value mapped to 1.
    """

    def __init__(self, lower_bound: float, upper_bound: float):
        super().__init__()
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound

    def forward(self, x: Tensor, reverse: bool = False):
        """Normalize ``x``, or undo the normalization when ``reverse`` is True."""
        span = self.upper_bound - self.lower_bound
        if reverse:
            return x * span + self.lower_bound
        return (x - self.lower_bound) / span


class StdNorm(nn.Module):
    """Standardize with a fixed mean and standard deviation.

    Args:
        mean: Value subtracted before scaling.
        std: Divisor applied after centering.
    """

    def __init__(self, mean: float, std: float):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, x: Tensor, reverse: bool = False):
        """Return ``(x - mean) / std``, or its inverse when ``reverse`` is True."""
        if reverse:
            return x * self.std + self.mean
        centered = x - self.mean
        return centered / self.std


class CosineCutoff(nn.Module):
    """Smooth cosine cutoff that decays from 1 at distance 0 to 0 at ``rbound_upper``.

    For ``d < rbound_upper`` the output is
    ``0.5 * (cos(pi * d / rbound_upper) + 1)``. For ``d >= rbound_upper`` it is 0.

    Args:
        rbound_upper: Cutoff radius. It is registered as a buffer, so it follows
            the module's device and appears in ``state_dict``.
    """

    def __init__(self, rbound_upper=5.0):
        super().__init__()
        self.register_buffer("rbound_upper", torch.tensor(rbound_upper))

    def forward(self, distances):
        """Apply the cutoff element-wise to ``distances``.

        The result has the dtype of the cosine term, so half/bfloat16 inputs
        are not silently upcast to float32.
        """
        ru = self.rbound_upper
        rbounds = 0.5 * (torch.cos(distances * math.pi / ru) + 1.0)
        # Cast the mask to the cosine term's dtype. ``.float()`` would promote
        # reduced-precision results to float32. Multiplying (rather than
        # torch.where) keeps the original NaN propagation for non-finite inputs.
        inside = (distances < ru).to(rbounds.dtype)
        return rbounds * inside


class ShiftedSoftplus(nn.Module):
    """Softplus shifted down by log(2), so that its value at 0 is 0."""

    def __init__(self):
        super().__init__()
        # log(2) is evaluated as a float32 torch scalar, then stored as a Python float.
        self.shift = float(torch.tensor(2.0).log())

    def forward(self, x):
        activated = F.softplus(x)
        return activated - self.shift


# Registry mapping activation names to module factories callable with no arguments.
# The silu/relu/selu entries are bound with inplace=True, so those modules
# overwrite their input tensor.
act_fn_dict = dict(
    ssp=ShiftedSoftplus,
    silu=partial(nn.SiLU, inplace=True),
    relu=partial(nn.ReLU, inplace=True),
    tanh=nn.Tanh,
    sigmoid=nn.Sigmoid,
    selu=partial(nn.SELU, inplace=True),
    identity=nn.Identity,
)
